Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainable Tokens: Support for Weight Tying #2399

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

githubnemo
Copy link
Collaborator

@githubnemo githubnemo commented Feb 25, 2025

This is a follow-up PR of #2376 to add support for weight-tying. Do not merge before the other is not merged.

What is this

Some models, such as gpt2, tie the weights between the LM head and the input embeddings for various reasons. If we use the trainable tokens adapter, we're changing the result of the forward() of the input embeddings but we do not change the weights (unless we merge()). This means that the changes are not reflected in the tied weights, such as the LM head, leading to wrong results when training.

How it is solved

The current approach is searching for tied layers and putting TrainableTokensLayer adapters on them as well but initialized to use the parameters from the embedding layer's TrainableTokensLayer. This is done via the tied_adapter argument of TrailableTokensLayer.__init__().

What needs to be done

  • encoder-decoder model tests
  • support for standalone TrainableTokens adapter
  • more tests

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@githubnemo githubnemo force-pushed the feature/custom-token-tuner-weight-tying branch from 69948b9 to ac70db6 Compare February 26, 2025 16:00
nemo added 11 commits February 26, 2025 17:21
Notably we are removing the duplication filter of `named_modules` when searching for
the (tied) target modules since tied weights are by definition duplicates.
It's now possible to let the adapter decide which is the input embedding layer based on the output
of `model.get_input_embeddings()`. If that fails, the default is still `embed_tokens`.
This is probably just a case of model misconfiguration but there are cases in the tests
where tie_embedding_weights is set to true in the config but no tied_weights_keys is set on the model.
Before this change only the selection of the module that was supposed to have the queried
attribute was given to the wrapper implemention (via `_{has,get}attr_wrapped`). Now the full
`getattr()` call is done by the implementation.

This change is motivated by the need for access to `embedding.weight` at certain times which,
for `ModulesToSaveWrapper` is not a problem - but it is for `TrainableTokensWrapper` since
the original module's weights differ from the current weights, at least potentially.

What we do now is to merge the weights and return those when `embedding.weight` is accessed.
No other attributes are currently forwarded.
Mixed batch is still broken, though.
Looking at you, stable diffusion
@githubnemo githubnemo marked this pull request as ready for review March 3, 2025 15:43
Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding support for trainable tokens with tied embeddings and enhancing the tests. This was more complex than I expected. Good work covering this many edge cases.

I have a couple of comments, but I think there is nothing major.

self.model.get_input_embeddings(), TrainableTokensWrapper
if (
model_config.get("tie_word_embeddings", False)
and self.model._tied_weights_keys is not None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about a comment why this check is required?

found, `embed_tokens`). Alternatively, you can specify a dictionary where the key is the name of the
embedding module and the values are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. Note
that training with FSDP/DeepSpeed might not yet be fully supported with this option enabled. Also note that
models using weight-tying are currently not supported.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adjust/delete?

self.trainable_tokens_original = BufferDict({})
self.token_indices = {}
else:
self.trainable_tokens_delta = self.tied_adapter.trainable_tokens_delta
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, you mentioned that stuff like model.state_dict() already correctly works for this case thanks to transformers. Could you add a comment for that? When the model is not a transformers model, would we get the same param twice?


# Mark the weight as unmerged
self.merged_adapters = []

def update_layer(self, adapter_name, **kwargs):
if kwargs.get("tied_adapter", None):
# in this case we don't have any say, we're just following whatever the tied
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# in this case we don't have any say, we're just following whatever the tied
# in this case we don't have any because we're just following whatever the tied

?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant to express that, as a tied layer, we don't have anything to do but to return. I'll clarify.

scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
sparse=self.base_layer.sparse,
)
elif isinstance(self.base_layer, torch.nn.Linear):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would not necessarily work with quantized models, right? I wonder if we can find a more robust way of handling this, but I'm not sure how exactly.

emb_in = peft_model.model.encoder.embed_tokens(torch.tensor([token_indices]))
emb_out = peft_model.model.lm_head(1 / emb_in)

assert all(torch.diag(emb_out[0]) == torch.tensor([emb_dim] * len(token_indices)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same argument as above

[
("model_emb", lambda model: model.emb),
("model_embed_in", lambda model: model.embed_in),
("model", lambda model: model.model.model.embed_tokens),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be more prudent to use operator.attrgetter than lambda but maybe it's unproblematic here.

@@ -655,6 +675,18 @@ def unload_and_optionally_merge_module(
return self.token_adapter.get_base_layer()


def _get_input_embeddings_name(model):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function could get a default argument that it returns instead of None, like getattr, but no strong opinion.

@@ -263,6 +274,32 @@ def check_config_json(self, tmp_dirname, model):
if hasattr(model, "config"): # custom models don't have a config attribute
assert config["base_model_name_or_path"] == model.config.to_dict()["_name_or_path"]

def perturb_trainable_token_weights_if_used(self, model, config_kwargs, adapter_name="default", weight=1.0):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def perturb_trainable_token_weights_if_used(self, model, config_kwargs, adapter_name="default", weight=1.0):
def perturb_trainable_token_weights_if_used(self, model, config_kwargs, adapter_name="default", scale=1.0):

Maybe more precise name?

model = get_peft_model(model, config)
model = model.to(self.torch_device)

self.perturb_trainable_token_weights_if_used(model, config_kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, a bit unfortunate that this is necessary. WDYT about using an auto fixture that monkey patches TrainableTokensLayer.update_layer so that it calls "super" but then perturbs the weight?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants